Skip to content

[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 11, 2025

Conversation

charithaintc
Copy link
Contributor

Current implementation generates incorrect code or crashes in the following valid cases.

  1. At least one of the for op results are not yielded by the warpOp.
    Example:
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
    ....
    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
      
      %1  = ...
      %acc = ....
      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
    }
    gpu.yield %3#0 : vector<128xf32> // %3#1 is not used but can not be removed as dead code (loop carried).
  }
  "some_use"(%0) : (vector<4xf32>) -> ()
  return
  1. Enclosing warpOp yields the forOp results in different order compared to the forOp results.
    Example:
  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
    ....
    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
      .....
      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
    }
    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> // swapped order
  }
  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
  "some_use_3"(%0#2) : (vector<8xf32>) -> ()

@llvmbot
Copy link
Member

llvmbot commented Jul 9, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

Current implementation generates incorrect code or crashes in the following valid cases.

  1. At least one of the for op results are not yielded by the warpOp.
    Example:
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -&gt; (vector&lt;4xf32&gt;) {
    ....
    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -&gt; (vector&lt;128xf32&gt;, vector&lt;128xf32&gt;) {
      
      %1  = ...
      %acc = ....
      scf.yield %acc, %1 : vector&lt;128xf32&gt;, vector&lt;128xf32&gt;
    }
    gpu.yield %3#<!-- -->0 : vector&lt;128xf32&gt; // %3#<!-- -->1 is not used but can not be removed as dead code (loop carried).
  }
  "some_use"(%0) : (vector&lt;4xf32&gt;) -&gt; ()
  return
  1. Enclosing warpOp yields the forOp results in different order compared to the forOp results.
    Example:
  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -&gt; (vector&lt;4xf32&gt;, vector&lt;4xf32&gt;, vector&lt;8xf32&gt;) {
    ....
    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -&gt; (vector&lt;256xf32&gt;, vector&lt;128xf32&gt;, vector&lt;128xf32&gt;) {
      .....
      scf.yield %acc1, %acc2, %acc3 : vector&lt;256xf32&gt;, vector&lt;128xf32&gt;, vector&lt;128xf32&gt;
    }
    gpu.yield %3#<!-- -->2, %3#<!-- -->1, %3#<!-- -->0 : vector&lt;128xf32&gt;, vector&lt;128xf32&gt;, vector&lt;256xf32&gt; // swapped order
  }
  "some_use_1"(%0#<!-- -->0) : (vector&lt;4xf32&gt;) -&gt; ()
  "some_use_2"(%0#<!-- -->1) : (vector&lt;4xf32&gt;) -&gt; ()
  "some_use_3"(%0#<!-- -->2) : (vector&lt;8xf32&gt;) -&gt; ()


Full diff: https://github.com/llvm/llvm-project/pull/147620.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+124-47)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+79)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index fb99e22c77ea0..3ce134fe5f3ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,8 +17,12 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1745,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
+    auto newWarpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     // Only pick up forOp if it is the last op in the region.
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = newWarpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
     // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to
-    // the new op.
+    // Those Value needs to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> inputTypes;
-    SmallVector<Type> distTypes;
+    SmallVector<Type> escapingValueInputTypes;
+    SmallVector<Type> escapingValuedistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValuedistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
+    // Warp op can yield two types of values:
+    // 1. Values that are not results of the forOp:
+    //    These values must also be yielded by the new warp op. Also, we need to
+    //    record the index mapping for these values to replace them later.
+    // 2. Values that are results of the forOp:
+    //    In this case, we record the index mapping between the warp op result
+    //    index and matching forOp result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    DenseMap<unsigned, unsigned> forResultMapping;
+    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+        continue;
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
+    }
 
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
+    // Newly created warp op will yield values in following order:
+    // 1. All init args of the forOp.
+    // 2. All escaping values.
+    // 3. All non-for yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValuedistTypes.begin(),
+                              escapingValuedistTypes.end());
+    // Next, we insert all non-for yielded values and their distributed types.
+    // We also create a mapping between the non-for yielded value index and the
+    // corresponding new warp op yield value index (needed to update users
+    // later).
+    DenseMap<unsigned, unsigned> warpResultMapping;
+    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
+      warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(
+          warpOp.getResult(nonForResultIndices[i]).getType());
+    }
+    // Create the new warp op with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    newWarpOpYield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
-    SmallVector<Value> newOperands;
-    SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
-        continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
-    }
+    // Next, we create a new for op with the init args yielded by the new
+    // warp op.
+    unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // ForOp init args are positioned before
+                                    // escaping values in the new warp op.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));
 
+    // Create a new for op outside the new warp op region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new warp op (called inner warp op) inside the
+    // newly created for op. This warp op will contain all ops that were
+    // contained within the original for op body.
     rewriter.setInsertionPointToStart(newForOp.getBody());
 
-    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
-                                 newForOp.getRegionIterArgs().end());
-    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
-                                    forOp.getResultTypes().end());
+    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+                                      newForOp.getRegionIterArgs().end());
+    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+                                         forOp.getResultTypes().end());
+    // Escaping values are forwarded to the inner warp op as its (additional)
+    // arguments. We keep track of the mapping between these values and their
+    // argument index in the inner warp op (to replcace uses later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
-      argIndexMapping[escapingValues[i]] = warpInputType.size();
-      warpInputType.push_back(inputTypes[i]);
+    for (size_t i = escapingValuesStartIdx;
+         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
+      innerWarpInput.push_back(newWarpOp.getResult(i));
+      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
+          innerWarpInputType.size();
+      innerWarpInputType.push_back(
+          escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
+    // Create the inner warp op with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
-        newWarpOp.getWarpSize(), warpInput, warpInputType);
+        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
 
+    // Inline the for op body into the inner warp op body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
-    for (Value args : innerWarp.getBody()->getArguments()) {
+    for (Value args : innerWarp.getBody()->getArguments())
       argMapping.push_back(args);
-    }
+
     argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
+
     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+    // Insert a gpu yieldOp at the end of the inner warp op body that yields
+    // original forOp results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
+    // Insert a scf.yield op at the end of the new for op body that yields
+    // the inner warp op results.
     if (!innerWarp.getResults().empty())
       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+    // Update the users of original warp op results that were coming from the
+    // original forOp to the corresponding new forOp result.
+    for (auto [origIdx, newIdx] : forResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newForOp.getResult(newIdx), newForOp);
+    // Similarly, update any users of the warp op results that were not
+    // results of the forOp.
+    for (auto [origIdx, newIdx] : warpResultMapping)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original warp op and for op, they should not have any uses
+    // at this point.
     rewriter.eraseOp(forOp);
-    // Replace the warpOp result coming from the original ForOp.
-    for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
-                                  newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
-    }
+    rewriter.eraseOp(warpOp);
+    // Update any users of escaping values that were forwarded to the
+    // inner warp op. These values are now arguments of the inner warp op.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 1161dbd4b2166..56fa38ce5a3e8 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:    %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+//       CHECK-PROP:    %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+//       CHECK-PROP:    gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP:  }
+//       CHECK-PROP:  scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+//       CHECK-PROP:  %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+//  CHECK-PROP-NEXT:    %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+//  CHECK-PROP-SAME:        vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:      ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+//  CHECK-PROP-NEXT:        %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+//  CHECK-PROP-NEXT:        %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:    }
+//  CHECK-PROP-NEXT:    scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+    %ini1 = "some_def"() : () -> (vector<256xf32>)
+    %ini2 = "some_def"() : () -> (vector<128xf32>)
+    %ini3 = "some_def"() : () -> (vector<128xf32>)
+    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+      %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+      %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+      %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+  }
+  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+  "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transformation looks correct to me, nice! Wouldn't we need basically the same thing for other scf ops? Seems like this can be generalized.

// for now.
if (!layout)
return AffineMap::getMultiDimMapWithTargets(
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes 2d, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. At SG level layout is currently only 2d. But upstream distribution makes no such assumption. Only assumption there is only 1 dim is distributed. We need to add more support there in future.

@charithaintc
Copy link
Contributor Author

The transformation looks correct to me, nice! Wouldn't we need basically the same thing for other scf ops? Seems like this can be generalized.

Thanks for the review Peter. Currently we don't have support for other scf ops (like scf.if, scf.while) which we may need later. I will think about generalization steps once we have more use cases for these ops. But appreciate your suggestion. This pattern is very complex to begin with, so some kind of abstraction to deal with these moving parts will def help.

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good!

Copy link
Contributor

@chencha3 chencha3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/ some minor comments.

@charithaintc charithaintc merged commit 3092b76 into llvm:main Jul 11, 2025
9 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 11, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/15477

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -lower-affine -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm   -convert-arith-to-llvm -gpu-lower-to-nvvm-pipeline |  /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner -e main -entry-point-result=void    -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so    -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_c_runner_utils.so    -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so |  /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-reduction-distribute.mlir '-test-vector-warp-distribute=hoist-uniform distribute-transfer-write propagate-distribution' -canonicalize
# .---command stderr------------
# | mlir-opt: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/lib/IR/TypeRange.cpp:20: mlir::TypeRange::TypeRange(llvm::ArrayRef<mlir::Type>): Assertion `llvm::all_of(types, [](Type t) { return t; }) && "attempting to construct a TypeRange with null types"' failed.
# | PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
# | Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
# | 0  mlir-opt  0x000057423066abeb llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 59
# | 1  mlir-opt  0x0000574230667524
# | 2  libc.so.6 0x00007af18a944520
# | 3  libc.so.6 0x00007af18a9989fc pthread_kill + 300
# | 4  libc.so.6 0x00007af18a944476 raise + 22
# | 5  libc.so.6 0x00007af18a92a7f3 abort + 211
# | 6  libc.so.6 0x00007af18a92a71b
# | 7  libc.so.6 0x00007af18a93be96
# | 8  mlir-opt  0x0000574233bfb671 mlir::TypeRange::TypeRange(llvm::ArrayRef<mlir::Type>) + 353
# | 9  mlir-opt  0x000057423324311e
# | 10 mlir-opt  0x0000574233220d17
# | 11 mlir-opt  0x0000574236a942db mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) + 1947
# | 12 mlir-opt  0x0000574233a8f121
# | 13 mlir-opt  0x0000574233a91cfa mlir::applyPatternsGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) + 3482
# | 14 mlir-opt  0x0000574234276f02
# | 15 mlir-opt  0x000057423427db7b
# | 16 mlir-opt  0x00005742339c0bee mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1262
# | 17 mlir-opt  0x00005742339c0f4a mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 250
# | 18 mlir-opt  0x00005742339c1342
# | 19 mlir-opt  0x00005742339c14cf
# | 20 mlir-opt  0x00005742339b537d std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<std::function<void ()>>>, void>>::_M_invoke(std::_Any_data const&) + 29
# | 21 mlir-opt  0x00005742339b5ef9
# | 22 libc.so.6 0x00007af18a99bee8
# | 23 mlir-opt  0x00005742339b55ec std::__future_base::_Deferred_state<std::thread::_Invoker<std::tuple<std::function<void ()>>>, void>::_M_complete_async() + 284
# | 24 mlir-opt  0x00005742339b70af std::_Function_handler<void (), std::shared_future<void> llvm::ThreadPoolInterface::asyncImpl<void>(std::function<void ()>, llvm::ThreadPoolTaskGroup*)::'lambda'()>::_M_invoke(std::_Any_data const&) + 31
# | 25 mlir-opt  0x0000574233bcacf5 llvm::StdThreadPool::processTasks(llvm::ThreadPoolTaskGroup*) + 645
# | 26 mlir-opt  0x0000574233bcbbf8
# | 27 libc.so.6 0x00007af18a996ac3
# | 28 libc.so.6 0x00007af18aa28850
# `-----------------------------
# error: command failed with exit status: -6
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -lower-affine -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -convert-arith-to-llvm -gpu-lower-to-nvvm-pipeline
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner -e main -entry-point-result=void -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_c_runner_utils.so -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so
# .---command stderr------------
# | Error: entry point not found
# `-----------------------------
# error: command failed with exit status: 1
...

charithaintc added a commit that referenced this pull request Jul 11, 2025
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jul 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants